{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lJz6FDU1lRzc"
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect\n",
    "\"\"\"\n",
    "# If you're using Google Colab and not running locally, run this cell.\n",
    "\n",
    "## Install dependencies\n",
    "!pip install wget\n",
    "!apt-get install sox libsndfile1 ffmpeg\n",
    "!pip install unidecode\n",
    "!pip install matplotlib>=3.3.2\n",
    "\n",
    "## Install NeMo\n",
    "BRANCH = 'main'\n",
    "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n",
    "\n",
    "## Grab the config we'll use in this example\n",
    "!mkdir configs\n",
    "!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/conf/config.yaml\n",
    "\n",
    "\"\"\"\n",
    "Remember to restart the runtime for the kernel to pick up any upgraded packages (e.g. matplotlib)!\n",
    "Alternatively, you can uncomment the exit() below to crash and restart the kernel, in the case\n",
    "that you want to use the \"Run All Cells\" (or similar) option.\n",
    "\"\"\"\n",
    "# exit()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "v1Jk9etFlRzf"
   },
   "source": [
    "# Telephony speech (8 kHz)\n",
    "This notebook covers general recommendations for using NeMo models with 8 kHz speech. All the pretrained models currently available through NeMo are trained with audio at 16 kHz. This means that if the original audio was sampled at a different rate, it's sampling rate was converted to 16 kHz through upsampling or downsampling. One of the common applications for ASR is to recognize telephony speech which typically consists of speech sampled at 8 kHz.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mixed sample rate\n",
    "Most of the pretrained English models distributed with NeMo are trained with mixed sample rate data, i.e. the training data typically consists of data sampled at both 8 kHz and 16 kHz. As an example pretrained Citrinet model \"stt_en_citrinet_1024\" was trained with the following datasets. \n",
    "* Librispeech 960 hours of English speech\n",
    "* Fisher Corpus\n",
    "* Switchboard-1 Dataset\n",
    "* WSJ-0 and WSJ-1\n",
    "* National Speech Corpus - 1\n",
    "* Mozilla Common Voice\n",
    "\n",
    "Among these, Fisher and Switchboard datasets are conversational telephone speech datasets with audio sampled at 8 kHz while the other datasets were originally sampled at least 16 kHz. Before training, all audio files from Fisher and Switchboard datasets were upsampled to 16 kHz. Because of this mixed sample rate training, our models can be used to recognize both narrowband (8kHz) and wideband speech (16kHz)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference with NeMo\n",
    "NeMo ASR currently supports inference of audio in .wav format. Internally, the audio file is resampled to 16 kHz before inference is called on the model, so there is no difference running inference on 8 kHz audio compared to say 16 kHz or any other sampling rate audio with NeMo. Let's look at an example for running inference on 8 kHz audio. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is where the an4/ directory will be placed.\n",
    "# Change this if you don't want the data to be extracted in the current directory.\n",
    "data_dir = '.'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "import subprocess\n",
    "import tarfile\n",
    "import wget\n",
    "\n",
    "# Download the dataset. This will take a few moments...\n",
    "print(\"******\")\n",
    "if not os.path.exists(data_dir + '/an4_sphere.tar.gz'):\n",
    "    an4_url = 'http://www.speech.cs.cmu.edu/databases/an4/an4_sphere.tar.gz'\n",
    "    an4_path = wget.download(an4_url, data_dir)\n",
    "    print(f\"Dataset downloaded at: {an4_path}\")\n",
    "else:\n",
    "    print(\"Tarfile already exists.\")\n",
    "    an4_path = data_dir + '/an4_sphere.tar.gz'\n",
    "\n",
    "if not os.path.exists(data_dir + '/an4/'):\n",
    "    # Untar and convert .sph to .wav (using sox)\n",
    "    tar = tarfile.open(an4_path)\n",
    "    tar.extractall(path=data_dir)\n",
    "\n",
    "    print(\"Converting .sph to .wav...\")\n",
    "    sph_list = glob.glob(data_dir + '/an4/**/*.sph', recursive=True)\n",
    "    for sph_path in sph_list:\n",
    "        wav_path = sph_path[:-4] + '.wav'\n",
    "        cmd = [\"sox\", sph_path, wav_path]\n",
    "        subprocess.run(cmd)\n",
    "print(\"Finished conversion.\\n******\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Audio in an4 dataset is sampled at 22 kHz. Let's first downsample an audio file to 16 kHz."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import librosa\n",
    "import IPython.display as ipd\n",
    "import librosa.display\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Load and listen to the audio file\n",
    "example_file = data_dir + '/an4/wav/an4_clstk/mgah/cen2-mgah-b.wav'\n",
    "audio, sample_rate = librosa.load(example_file)\n",
    "print(sample_rate)\n",
    "audio_16kHz = librosa.core.resample(audio, sample_rate, 16000)\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Get spectrogram using Librosa's Short-Time Fourier Transform (stft)\n",
    "spec = np.abs(librosa.stft(audio_16kHz))\n",
    "spec_db = librosa.amplitude_to_db(spec, ref=np.max)  # Decibels\n",
    "\n",
    "# Use log scale to view frequencies\n",
    "librosa.display.specshow(spec_db, y_axis='log', x_axis='time', sr=16000)\n",
    "plt.colorbar()\n",
    "plt.title('Audio Spectrogram');\n",
    "plt.ylim([0, 8000])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's downsample the audio to 8 kHz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "audio_8kHz = librosa.core.resample(audio, sample_rate, 8000)\n",
    "spec = np.abs(librosa.stft(audio_8kHz))\n",
    "spec_db = librosa.amplitude_to_db(spec, ref=np.max)  # Decibels\n",
    "\n",
    "# Use log scale to view frequencies\n",
    "librosa.display.specshow(spec_db, y_axis='log', x_axis='time',  sr=8000)\n",
    "plt.colorbar()\n",
    "plt.title('Audio Spectrogram');\n",
    "plt.ylim([0, 8000])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import soundfile as sf\n",
    "sf.write(data_dir + '/audio_16kHz.wav', audio_16kHz, 16000)\n",
    "sample, sr = librosa.core.load(data_dir + '/audio_16kHz.wav')\n",
    "ipd.Audio(sample, rate=sr)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sf.write(data_dir + '/audio_8kHz.wav', audio_8kHz, 8000)\n",
    "sample, sr = librosa.core.load(data_dir + '/audio_8kHz.wav')\n",
    "ipd.Audio(sample, rate=sr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Let's look at inference results using one of the pretrained models on the original, 16 kHz and 8 kHz versions of the example file we chose above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.asr.models import ASRModel\n",
    "import torch\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(f'cuda:0')\n",
    "asr_model = ASRModel.from_pretrained(model_name='stt_en_citrinet_1024', map_location=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As discussed above, there are no changes required for inference based on the sampling rate of audio and as we see below the pretrained Citrinet model gives accurate transcription even for audio downsampled to 8 Khz."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(asr_model.transcribe(paths2audio_files=[example_file]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(asr_model.transcribe(paths2audio_files=[data_dir + '/audio_16kHz.wav']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(asr_model.transcribe(paths2audio_files=[data_dir + '/audio_8kHz.wav']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training / fine-tuning with 8 kHz data\n",
    "For training a model with new 8 kHz data, one could take two approaches. The first approach, **which is recommended**, is to finetune a pretrained 16 kHz model by upsampling all the data to 16 kHz. Note that upsampling offline before training is not necessary but recommended as online upsampling during training is very time consuming and may slow down training significantly. The second approach is to train an 8 kHz model from scratch. **Note**: For the second approach, in our experiments we saw that loading the weights of a 16 kHz model as initialization helps the model to converge faster with better accuracy.\n",
    "\n",
    "To upsample your 8 kHz data to 16 kHz command line tools like sox or ffmpeg are very useful. Here is the command to upsample and audio file using sox:\n",
    "```shell\n",
    "sox input_8k.wav -r 16000 -o output_16k.wav\n",
    "```\n",
    "Now to finetune a pre-trained model with this upsampled data, you can just restore the model weights from the pre-trained model and call trainer with the upsampled data. As an example, here is how one would fine-tune a Citrinet model:\n",
    "```python\n",
    "python examples/asr/script_to_bpe.py \\\n",
    "    --config-path=\"examples/asr/conf/citrinet\" \\\n",
    "    --config-name=\"citrinet_512.yaml\" \\\n",
    "    model.train_ds.manifest_filepath=\"<path to manifest file with upsampled 16kHz data>\" \\\n",
    "    model.validation_ds.manifest_filepath=\"<path to manifest file>\" \\\n",
    "    trainer.gpus=-1 \\\n",
    "    trainer.max_epochs=50 \\\n",
    "    +init_from_pretrained_model=\"stt_en_citrinet_512\"\n",
    "```\n",
    "\n",
    "To train an 8 kHz model, just change the sample rate in the config to 8000 as follows:\n",
    "\n",
    "```python\n",
    "python examples/asr/script_to_bpe.py \\\n",
    "    --config-path=\"examples/asr/conf/citrinet\" \\\n",
    "    --config-name=\"citrinet_512.yaml\" \\\n",
    "    model.sample_rate=8000 \\\n",
    "    model.train_ds.manifest_filepath=\"<path to manifest file with 8kHz data>\" \\\n",
    "    model.validation_ds.manifest_filepath=\"<path to manifest file>\" \\\n",
    "    trainer.gpus=-1 \\\n",
    "    trainer.max_epochs=50 \\\n",
    "    +init_from_pretrained_model=\"stt_en_citrinet_512\"\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "ASR_with_NeMo.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
